% Start from a clean command window and workspace, then load the input data.
% NOTE(review): Data.mat is expected to provide the table `Data` used below.
clc
clearvars
load('Data')
format short g


%=========== Define simulation parameters ===========%
max_iter    = 100;   % number of subsample iterations run by the parfor loop
market_size = 100;   % number of households drawn per simulated market
cluster     = 1;     % 0 -> local run (gurobi), anything else -> cluster run (mosek)

% Character versions of the parameters, used to build the results file name
text_cluster = num2str(cluster);
text_size    = num2str(market_size);
text_iter    = num2str(max_iter);


%=========== Define variables to store results ===========%
% Row / cell `iter` is filled by parfor iteration `iter`; the rest stays NaN / empty.
% shareM/shareF rows hold interleaved [min,max] RICEB pairs followed by the iteration number.
[shareM, shareF] = deal(NaN(max_iter, 13));
DivorceCosts = NaN(max_iter, 12);
[HHIndices, StabilityIndices] = deal(cell(max_iter, 1));


% Pick the optimisation solver and the parfor worker cap (parnum) by run
% environment. Non-zero `cluster` also puts the YALMIP and MOSEK toolbox
% folders found at the cluster's install paths on the MATLAB path.
if cluster ~= 0
    text_solver = 'mosek';
    parnum = 20;
    addpath(genpath("/opt/apps/easybuild/software/YALMIP/R20230609/"))
    addpath(genpath("~/mosek/10.1/toolbox/r2017a"))
else
    text_solver = 'gurobi';
    parnum = 2;
end


%=========== Define chores = housework + childcare hours ===========%
% Male hours are zeroed in single-female households, female hours in
% single-male households.
Data.chorehm = Data.hworkm + Data.chcarem;
Data.chorehm(Data.singlefemale == 1) = 0;
Data.chorehf = Data.hworkf + Data.chcaref;
Data.chorehf(Data.singlemale == 1) = 0;


%=========== Define povertyline ===========%
% Household consumption is q + Q. Per-capita consumption halves it for rows
% with couple == 1 and keeps it whole for rows with couple == 0; the poverty
% line is 60% of the median per-capita value.
% NOTE(review): rows whose couple flag is neither 0 nor 1 are not assigned here.
Data.consumption = Data.q + Data.Q;
coupleRows = find(Data.couple == 1);
singleRows = find(Data.couple == 0);
Data.percapitaconsumption(coupleRows) = Data.consumption(coupleRows)/2;
Data.percapitaconsumption(singleRows) = Data.consumption(singleRows);
clear coupleRows singleRows   % keep the saved workspace unchanged
povertyline = 0.6*median(Data.percapitaconsumption);


%=========== Define age restrictions for marriage market ===========%
% Male-minus-female age gap among couples; its 2.5th and 97.5th percentiles
% (lb, ub) are the bounds later passed to get_consideration_sets.
age_diff = Data.agem(Data.couple == 1) - Data.agef(Data.couple == 1);
[lb, ub] = deal(prctile(age_diff, 2.5), prctile(age_diff, 97.5));


%===========  Define age, education and children categories for preference types ===========%
% Classify each row in place. A category is written only by the branch that
% matches; a value matching no branch (an age <= 1, a negative child count)
% leaves that table entry as it was. Types pack the categories as
% 100*education + 10*children + age bracket.
for i = 1:size(Data,1)
    ageM = Data.agem(i);
    ageF = Data.agef(i);
    kids = Data.nchild(i);

    % Male age bracket: 0 = missing, 1 = (1,35], 2 = (35,50], 3 = above 50
    if isnan(ageM)
        Data.agecatm(i) = 0;
    elseif ageM > 1 && ageM <= 35
        Data.agecatm(i) = 1;
    elseif ageM > 35 && ageM <= 50
        Data.agecatm(i) = 2;
    elseif ageM > 50
        Data.agecatm(i) = 3;
    end

    % Female age bracket, same cut-offs
    if isnan(ageF)
        Data.agecatf(i) = 0;
    elseif ageF > 1 && ageF <= 35
        Data.agecatf(i) = 1;
    elseif ageF > 35 && ageF <= 50
        Data.agecatf(i) = 2;
    elseif ageF > 50
        Data.agecatf(i) = 3;
    end

    % Education category: the raw code, 0 when missing
    if ~isnan(Data.edum(i))
        Data.educatm(i) = Data.edum(i);
    else
        Data.educatm(i) = 0;
    end
    if ~isnan(Data.eduf(i))
        Data.educatf(i) = Data.eduf(i);
    else
        Data.educatf(i) = 0;
    end

    % Children: 0 = missing, 1 = none, 2 = at least one
    if isnan(kids)
        Data.kidscat(i) = 0;
    elseif kids == 0
        Data.kidscat(i) = 1;
    elseif kids > 0
        Data.kidscat(i) = 2;
    end

    % Pack into type codes, reading the categories back from the table
    Data.typem(i) = Data.educatm(i)*100 + Data.kidscat(i)*10 + Data.agecatm(i);
    Data.typef(i) = Data.educatf(i)*100 + Data.kidscat(i)*10 + Data.agecatf(i);

    % The absent partner of a single household gets type 0
    if Data.singlemale(i) == 1
        Data.typef(i) = 0;
    elseif Data.singlefemale(i) == 1
        Data.typem(i) = 0;
    end
end
clear ageM ageF kids   % keep the saved workspace unchanged


%=========== Define household types and weights ===========%
% Household type = 1000*male type + female type. value_counts has one row per
% distinct type: [type, count, share of all rows].
Data.types = 1000*Data.typem + Data.typef;
[C,~,ic] = unique(Data.types);
a_counts = accumarray(ic,1);
value_counts = [C, a_counts, a_counts./sum(a_counts)];


%=========== Define couple type based on education and employment status of spouses ===========%
% Employment category per spouse, computed for all rows at once:
%   1 = wage missing (NaN)
%   2 = wage observed, fulltime flag ~= 1
%   3 = wage observed, fulltime flag == 1
Data.empcatm = 1 + (~isnan(Data.wagem)) .* (1 + (Data.fulltimem == 1));
Data.empcatf = 1 + (~isnan(Data.wagef)) .* (1 + (Data.fulltimef == 1));

% The absent spouse gets category 0. The single-female test takes precedence,
% so the female reset only applies to single-male rows not flagged single female.
singleF = Data.singlefemale == 1;
Data.empcatm(singleF) = 0;
Data.empcatf(Data.singlemale == 1 & ~singleF) = 0;
clear singleF   % keep the saved workspace unchanged

%========== Check numbers
%[C,~,ic] = unique(Data.empcatf);
%a_counts = accumarray(ic,1);
%value_counts = [C, a_counts];


% Household category: for couples, each spouse's code is
% 10*employment category + education category, and the household code is
% 100*male code + female code. Single-male and single-female rows get 0;
% rows matching none of the flags are left unchanged.
for i = 1:size(Data,1)
    if Data.couple(i) == 1
        Data.hcatm(i) = 10*Data.empcatm(i) + Data.educatm(i);
        Data.hcatf(i) = 10*Data.empcatf(i) + Data.educatf(i);
        Data.hcat(i) = 100*Data.hcatm(i) + Data.hcatf(i);
    elseif Data.singlemale(i) == 1 || Data.singlefemale(i) == 1
        Data.hcat(i) = 0;
    end
end



%%
%=========== Estimate the model with subsamples ===========% 
% One Threefry RandStream is copied to every worker; setting Substream = iter
% gives each iteration its own reproducible draws regardless of scheduling.
sc = parallel.pool.Constant(RandStream('Threefry'));


parfor (iter = 1:max_iter,parnum)

    stream = sc.Value;
    stream.Substream = iter;

    % Redraw markets until one yields solved indices and usable RICEBs.
    % NOTE(review): there is no cap on redraws, so an iteration whose draws
    % never solve will not terminate.
    complete = 0;
    while complete == 0

        %=========== randomly draw household types from their weighted distribution ===========% 
        rand_index = randsample(stream,value_counts(:,1), market_size, true, value_counts(:,3));
        [C,~,ic] = unique(rand_index);
        a_counts = accumarray(ic,1);
        index_counts = [C, a_counts];

        %=========== Given the number of household types required, randomly draw households of that type from the observed households ===========%
        Data_r = [];
        for i = 1:size(index_counts,1)
            Extract = Data(Data.types == index_counts(i,1),:);
            Extract_index = randsample(stream,size(Extract,1),index_counts(i,2),true);
            Data_r = [Data_r; Extract(Extract_index,:)];
        end

        %=========== Define consideration set ===========% 
        consideration_set = get_consideration_sets(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,Data_r.agem,Data_r.agef,lb,ub);

        %=========== Compute stability indices ===========% 
        fprintf('Computing Divorce Costs...\n');
        [temp_DivorceCostsM,temp_DivorceCostsF,temp_DivorceCostsMF,solved] = get_indices(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
            Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
            Data_r.nonlabor,consideration_set,text_solver);

        %=========== Compute RICEBs (only when get_indices returned solved == 0) ===========% 
        if solved == 0
            fprintf('Computing RICEBs...\n');
            [temp_min_male,temp_max_male,temp_min_female,temp_max_female,problem] = get_ricebs(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
                Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
                Data_r.nonlabor,consideration_set,temp_DivorceCostsM+eps,temp_DivorceCostsF+eps,temp_DivorceCostsMF+eps,Data_r.hcat,text_solver);

            % Accept the draw when each bound array has at least one non-NaN
            % entry and get_ricebs reported no problem.
            if any(~isnan(temp_min_male(:))) && any(~isnan(temp_max_male(:))) && ...
                    any(~isnan(temp_min_female(:))) && any(~isnan(temp_max_female(:))) && problem == 0

                %=========== Store stability indices ===========%
                coupleIdx = find(Data_r.couple);              % rows with a nonzero couple flag
                coupleCons = Data_r.consumption(coupleIdx);   % normaliser for couples' costs
                ncouple = numel(coupleIdx);

                % Pairs outside the consideration set are excluded (NaN)
                temp_DivorceCostsMF(~consideration_set) = NaN;
                DC_temp = temp_DivorceCostsMF;

                % Max / mean over rows (the *_m statistics) and columns (the *_f
                % statistics), skipping excluded pairs; kept for couple rows only.
                % A row/column with no included pair gives a NaN mean (0/0).
                DivorceCosts_maxm = max(DC_temp,[],2);
                DivorceCosts_maxm = DivorceCosts_maxm(coupleIdx);
                DivorceCosts_maxf = max(DC_temp)';
                DivorceCosts_maxf = DivorceCosts_maxf(coupleIdx);
                DivorceCosts_avgf = sum(DC_temp,'omitnan')' ./ sum(~isnan(DC_temp))';
                DivorceCosts_avgf = DivorceCosts_avgf(coupleIdx);
                DivorceCosts_avgm = sum(DC_temp,2,'omitnan') ./ sum(~isnan(DC_temp),2);
                DivorceCosts_avgm = DivorceCosts_avgm(coupleIdx);

                % Columns 1-6: temp_DivorceCostsM/F, row/column max, row/column mean,
                % each divided by household consumption; 7: iteration number;
                % 8-9: row/column sizes of the consideration set.
                DivorceCosts_temp = zeros(ncouple,9);
                DivorceCosts_temp(:,1) = temp_DivorceCostsM(coupleIdx)./coupleCons;
                DivorceCosts_temp(:,2) = temp_DivorceCostsF(coupleIdx)./coupleCons;
                DivorceCosts_temp(:,3) = DivorceCosts_maxm./coupleCons;
                DivorceCosts_temp(:,4) = DivorceCosts_maxf./coupleCons;
                DivorceCosts_temp(:,5) = DivorceCosts_avgm./coupleCons;
                DivorceCosts_temp(:,6) = DivorceCosts_avgf./coupleCons;
                DivorceCosts_temp(:,7) = iter;
                mmsize = sum(consideration_set,2,'omitnan');
                DivorceCosts_temp(:,8) = mmsize(coupleIdx);
                mmsize = sum(consideration_set,'omitnan')';
                DivorceCosts_temp(:,9) = mmsize(coupleIdx);

                % Share of couples whose normalised cost is numerically zero
                NBPStable = DivorceCosts_temp(:,5) <= 0.00001 & DivorceCosts_temp(:,6) <= 0.00001;
                IRMStable = DivorceCosts_temp(:,1) <= 0.00001;
                IRFStable = DivorceCosts_temp(:,2) <= 0.00001;
                % mean along dim 1 explicitly: with a single couple the 1x9 row
                % must not collapse to a scalar (12 columns are expected)
                DivorceCosts(iter,:) = [mean(DivorceCosts_temp,1), mean(NBPStable), mean(IRMStable), mean(IRFStable)];

                % Stored matrix: DC_temp with temp_DivorceCostsM appended as a
                % column and temp_DivorceCostsF as a row (non-couples set to NaN),
                % bottom-right entry 0
                temp_DivorceCostsM(~Data_r.couple) = NaN;
                temp_DivorceCostsF(~Data_r.couple) = NaN;
                DC_temp = [DC_temp, temp_DivorceCostsM];
                DC_temp = [DC_temp; [temp_DivorceCostsF',0]];
                StabilityIndices{iter} = DC_temp;
                HHIndices{iter} = Data_r.familyid;

                %=========== Store RICEBs ===========%
                % Row layout: [min_1, max_1, min_2, max_2, ..., iter]
                male_temp = zeros(1,2*size(temp_min_male,1)+1);
                for i = 1:size(temp_min_male,1)
                    male_temp(1,2*i-1) = temp_min_male(i);
                    male_temp(1,2*i) = temp_max_male(i);
                end
                male_temp(1,2*size(temp_min_male,1)+1) = iter;
                shareM(iter,:) = male_temp;

                female_temp = zeros(1,2*size(temp_min_female,1)+1);
                for i = 1:size(temp_min_female,1)
                    female_temp(1,2*i-1) = temp_min_female(i);
                    female_temp(1,2*i) = temp_max_female(i);
                end
                female_temp(1,2*size(temp_min_female,1)+1) = iter;
                shareF(iter,:) = female_temp;

                complete = 1;
                fprintf('*********** Iteration %d finished *************\n', iter);
            end

        end
    end
end

% Save the whole workspace under a name encoding market size, iteration count and cluster flag
filename = horzcat('DataPSID_results_ricebs_4cat_size', text_size, '_iter', text_iter, '_Cluster', text_cluster);
save(filename);


%============== TABLE 24 =======================%
% Print the NaN-omitting mean of each category's lower/upper RICEB bound.
% Columns 2c-1 and 2c of shareM / shareF hold category c's [min, max]; the
% first six formats use shareM, the last six shareF.
table24_formats = {
    'Bounds for male employed = no and education = low = [%.4f, %.4f] \n'
    'Bounds for male employed = no and education = high = [%.4f, %.4f] \n'
    'Bounds for male employed = yes, part-time and education = low = [%.4f, %.4f] \n'
    'Bounds for male employed = yes, part-time and education = high = [%.4f, %.4f] \n'
    'Bounds for male employed = yes, full-time and education = low = [%.4f, %.4f] \n'
    'Bounds for male employed = yes, full-time and education = high = [%.4f, %.4f] \n'
    'Bounds for female employed = no and education = low = [%.4f, %.4f] \n'
    'Bounds for female employed = no and education = high = [%.4f, %.4f] \n'
    'Bounds for female employed = yes, part-time and education = low = [%.4f, %.4f] \n'
    'Bounds for female employed = yes, part-time and education = high = [%.4f, %.4f] \n'
    'Bounds for female employed = yes, full-time and education = low = [%.4f, %.4f] \n'
    'Bounds for female employed = yes, full-time and education = high = [%.4f, %.4f] \n'
    };
for k = 1:numel(table24_formats)
    if k <= 6
        table24_src = shareM;
    else
        table24_src = shareF;
    end
    c = mod(k-1, 6) + 1;   % category index within the male / female block
    fprintf(table24_formats{k}, mean(table24_src(:,2*c-1), "omitnan"), mean(table24_src(:,2*c), "omitnan"));
end
clear table24_formats table24_src k c


